
Add GradientAccumulation utility for SupervisedTrainer#8763

Open
aymuos15 wants to merge 7 commits into Project-MONAI:dev from aymuos15:feat/grad-accum-supervisedtrainer

Conversation


@aymuos15 aymuos15 commented Mar 3, 2026

Summary

  • Adds a GradientAccumulation callable class in monai.engines.utils for use as iteration_update in SupervisedTrainer, enabling gradient accumulation over multiple mini-batches to simulate larger effective batch sizes on memory-constrained hardware
  • Follows the callable-class iteration_update pattern established by Interaction in monai.apps.deepedit (as referenced by @wyli in Add gradient accumulation logic to SupervisedTrainer #6101)
  • All IterationEvents fire every mini-batch, so existing handlers are unaffected
  • Epoch boundary flush ensures no gradients are silently discarded when epoch_length % accumulation_steps != 0
  • Mixed-precision (GradScaler) support included

Closes #6100
Supersedes #6101

Usage

from monai.engines import SupervisedTrainer, GradientAccumulation

trainer = SupervisedTrainer(
    ...,
    iteration_update=GradientAccumulation(accumulation_steps=4),
)

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • New tests added to cover the changes.
  • In-line docstrings updated.

Test plan

  • Input validation (zero, negative, float, string)
  • Passthrough when accumulation_steps=1
  • zero_grad / optimizer.step suppression patterns verified across full epochs
  • Epoch boundary flush when epoch_length not divisible by accumulation_steps
  • Iterable dataset (epoch_length=None) — no epoch flush
  • Patching/restoration of all engine methods after each call
  • Restoration after exception (try/finally)
  • GradScaler patching when step suppressed, not patched when stepping
  • No scaler attribute / scaler=None edge cases
  • Batch data forwarded correctly to _iteration
  • Output loss is unscaled (original value for loggers/metrics)
  • Integration: gradient equivalence with manual accumulation (requires ignite)
  • Integration: epoch boundary flush equivalence (requires ignite)
  • Integration: multi-epoch correctness (requires ignite)
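The gradient-equivalence checks above rest on a simple identity: one optimizer step on the mean loss of k mini-batches equals k accumulated backward passes on loss / k followed by one delayed step. A hypothetical pure-Python sketch of that identity (a scalar linear model stands in for the network; none of these names come from the PR's test file):

```python
# Hypothetical sketch of the equivalence the integration tests check:
# accumulating k per-batch gradients of (loss / k) reproduces one SGD
# step on the mean loss over the same k batches.
k, lr, w0 = 4, 0.1, 0.5
xs = [1.0, -2.0, 0.5, 3.0]
ys = [0.2, 1.0, -0.3, 0.7]

def grad(w, x, y):
    # d/dw of the squared error (w * x - y) ** 2
    return 2.0 * x * (w * x - y)

# Reference: one optimizer step on the mean loss over all k samples.
ref_w = w0 - lr * sum(grad(w0, x, y) for x, y in zip(xs, ys)) / k

# Accumulated: k "backward passes" on loss / k, then one delayed step.
acc_grad = 0.0
for x, y in zip(xs, ys):
    acc_grad += grad(w0, x, y) / k  # loss scaled by 1/k before backward
acc_w = w0 - lr * acc_grad

assert abs(ref_w - acc_w) < 1e-12  # identical up to float rounding
```

The actual integration tests make the same comparison with torch models driven through the trainer, asserting the final weights of the accumulated and reference runs match.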


coderabbitai bot commented Mar 3, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough


This pull request adds gradient accumulation functionality to MONAI's SupervisedTrainer. A new parameter accumulation_steps is added to the trainer initialization with validation ensuring positive integer values. During training iterations, gradients are scaled by dividing loss by the accumulation step count, with optimizer updates conditioned on whether the specified accumulation threshold is reached or an epoch boundary occurs. The loss remains unscaled in the trainer's output state. A comprehensive test suite validates input validation, passthrough behavior when accumulation is disabled, gradient equivalence between accumulated and non-accumulated training paths, epoch boundary handling, and multi-epoch consistency.
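The step-or-flush decision described above can be sketched as a small predicate (illustrative only; should_step and its signature are hypothetical names, not the actual SupervisedTrainer API):

```python
# Hedged sketch of the decision logic described in the walkthrough; not
# the actual MONAI code. `iteration` is the 1-based global iteration
# counter (as in ignite); all names here are illustrative.
from __future__ import annotations

def should_step(iteration: int, accumulation_steps: int, epoch_length: int | None) -> bool:
    """True when the optimizer should step: the accumulation window is
    full, or the epoch ends (flush, so no gradients are discarded)."""
    if iteration % accumulation_steps == 0:
        return True
    # epoch-boundary flush; skipped for iterable datasets (epoch_length=None)
    return epoch_length is not None and iteration % epoch_length == 0

# accumulation_steps=4, epoch_length=10: steps at 4 and 8, flush at 10
steps = [i for i in range(1, 11) if should_step(i, 4, 10)]
assert steps == [4, 8, 10]
```

In the trainer itself the loss used for backward is divided by the accumulation step count, while the value placed in the engine's output state stays unscaled for loggers and metrics.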

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~18 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage (⚠️ Warning): docstring coverage is 70.59%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (4 passed)

  • Title check (✅ Passed): the title "Add GradientAccumulation utility for SupervisedTrainer" clearly and concisely summarizes the main change.
  • Description check (✅ Passed): the description covers objectives, implementation details, usage examples, and comprehensive test plans aligned with the template.
  • Linked Issues check (✅ Passed): changes fully meet issue #6100 objectives: gradient accumulation functionality, preserved MONAI/Ignite events, edge case handling, mixed-precision support, clean API.
  • Out of Scope Changes check (✅ Passed): all changes directly support gradient accumulation (trainer modifications, utility implementation, and comprehensive test coverage); no extraneous modifications detected.



@coderabbitai coderabbitai bot left a comment
🧹 Nitpick comments (1)
tests/engines/test_gradient_accumulation.py (1)

339-350: Drop the unused init_weight from the helper return.
It is not consumed by callers, so removing it tightens the helper contract and avoids dead unpacks downstream.

♻️ Proposed cleanup
@@
-def _make_model_pair(lr):
+def _make_model_pair(lr):
@@
-    return ref_model, test_model, ref_opt, test_opt, init_weight
+    return ref_model, test_model, ref_opt, test_opt
@@
-        ref_model, test_model, ref_opt, test_opt, init_weight = _make_model_pair(lr)
+        ref_model, test_model, ref_opt, test_opt = _make_model_pair(lr)
@@
-        ref_model, test_model, ref_opt, test_opt, init_weight = _make_model_pair(lr)
+        ref_model, test_model, ref_opt, test_opt = _make_model_pair(lr)
@@
-        ref_model, test_model, ref_opt, test_opt, init_weight = _make_model_pair(lr)
+        ref_model, test_model, ref_opt, test_opt = _make_model_pair(lr)

ℹ️ Review info

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting

📥 Commits

Reviewing files that changed from the base of the PR and between 894068a and 597e086.

📒 Files selected for processing (3)
  • monai/engines/__init__.py
  • monai/engines/utils.py
  • tests/engines/test_gradient_accumulation.py

…ject-MONAI#6100)

Closes Project-MONAI#6100

Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
@aymuos15 force-pushed the feat/grad-accum-supervisedtrainer branch from 597e086 to 1db8cc1 on March 3, 2026 at 11:08
@coderabbitai coderabbitai bot left a comment

🧹 Nitpick comments (2)
tests/engines/test_gradient_accumulation.py (1)

105-105: Consider marking intentionally-unused bindings with _ prefixes.

This keeps tests clear while avoiding avoidable lint noise.

🧹 Optional cleanup
-        def fake_iteration(eng, batch):
+        def fake_iteration(eng, _batch):
@@
-        def check_scaler(eng, batch):
+        def check_scaler(eng, _batch):
@@
-        def fake_iteration(*args, **kwargs):
+        def fake_iteration(*_args, **_kwargs):
@@
-        ref_model, test_model, ref_opt, test_opt, init_weight = _make_model_pair(lr)
+        ref_model, test_model, ref_opt, test_opt, _init_weight = _make_model_pair(lr)
@@
-        ref_model, test_model, ref_opt, test_opt, init_weight = _make_model_pair(lr)
+        ref_model, test_model, ref_opt, test_opt, _init_weight = _make_model_pair(lr)
@@
-        ref_model, test_model, ref_opt, test_opt, init_weight = _make_model_pair(lr)
+        ref_model, test_model, ref_opt, test_opt, _init_weight = _make_model_pair(lr)

Also applies to: 188-188, 234-234, 257-257, 287-287, 318-318

monai/engines/utils.py (1)

366-368: Align new definitions with Google-style docstring sections.

_noop, __init__, and __repr__ should include explicit Args/Returns (and Raises where applicable) sections to match repo docstring policy.

♻️ Suggested docstring adjustments
 def _noop(*args: Any, **kwargs: Any) -> None:
-    """No-op callable used to suppress optimizer/scaler methods during gradient accumulation."""
+    """No-op callable used to suppress optimizer/scaler methods.
+
+    Args:
+        *args: Ignored positional arguments.
+        **kwargs: Ignored keyword arguments.
+
+    Returns:
+        None.
+    """

 class GradientAccumulation:
@@
     def __init__(self, accumulation_steps: int = 2) -> None:
+        """Initialize gradient accumulation behavior.
+
+        Args:
+            accumulation_steps: Number of mini-batches to accumulate before stepping.
+
+        Raises:
+            ValueError: If `accumulation_steps` is not a positive integer.
+        """
         if not isinstance(accumulation_steps, int) or accumulation_steps < 1:
@@
     def __repr__(self) -> str:
+        """Return a debug-friendly representation.
+
+        Returns:
+            String representation with configured accumulation steps.
+        """
         return f"GradientAccumulation(accumulation_steps={self.accumulation_steps})"

As per coding guidelines, "Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings."

Also applies to: 405-413



📥 Commits

Reviewing files that changed from the base of the PR and between 597e086 and 1db8cc1.

📒 Files selected for processing (3)
  • monai/engines/__init__.py
  • monai/engines/utils.py
  • tests/engines/test_gradient_accumulation.py

Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
@coderabbitai coderabbitai bot left a comment
Actionable comments posted: 3

🧹 Nitpick comments (1)
monai/engines/utils.py (1)

413-413: Widen batchdata type hint in __call__.

batchdata: dict[str, Any] is tighter than common trainer inputs. Consider Any to avoid misleading static typing for tuple/list batch payloads.

Proposed fix
-    def __call__(self, engine: Any, batchdata: dict[str, Any]) -> dict:
+    def __call__(self, engine: Any, batchdata: Any) -> dict:
Additional inline comment (tests/engines/test_gradient_accumulation.py, line 91): callback parameters that are intentionally unused (for example in fake_iteration) trigger Ruff's ARG001; prefix them with underscores (e.g. def fake_iteration(_eng, _batch):) so the unused bindings are clearly marked and lint-clean.


📥 Commits

Reviewing files that changed from the base of the PR and between 1db8cc1 and a3eca14.

📒 Files selected for processing (2)
  • monai/engines/utils.py
  • tests/engines/test_gradient_accumulation.py

Comment on lines +406 to +407
if not isinstance(accumulation_steps, int) or accumulation_steps < 1:
raise ValueError(f"`accumulation_steps` must be a positive integer, got {accumulation_steps!r}.")
@coderabbitai coderabbitai bot Mar 3, 2026

⚠️ Potential issue | 🟡 Minor

Reject boolean values for accumulation_steps.

True currently passes validation because bool is an int subclass, so invalid config can silently map to 1.

Proposed fix
-        if not isinstance(accumulation_steps, int) or accumulation_steps < 1:
+        if isinstance(accumulation_steps, bool) or not isinstance(accumulation_steps, int) or accumulation_steps < 1:
             raise ValueError(f"`accumulation_steps` must be a positive integer, got {accumulation_steps!r}.")
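The subclass quirk behind this finding is easy to demonstrate in isolation; the sketch below applies the guard suggested in the diff (a standalone illustration, not MONAI code):

```python
# bool is a subclass of int in Python, so `True` satisfies
# isinstance(x, int) and would silently behave as accumulation_steps == 1.
def validate(accumulation_steps):
    # guard from the suggested fix: reject bools explicitly, then ints < 1
    if (
        isinstance(accumulation_steps, bool)
        or not isinstance(accumulation_steps, int)
        or accumulation_steps < 1
    ):
        raise ValueError(
            f"`accumulation_steps` must be a positive integer, got {accumulation_steps!r}."
        )

assert isinstance(True, int) and True == 1  # the surprising subclass facts

for bad in (0, -1, 2.5, "2", True, False):
    try:
        validate(bad)
        raise AssertionError(f"{bad!r} should have been rejected")
    except ValueError:
        pass

validate(4)  # genuine positive ints still pass
```

(Note that the maintainer response below takes a different view on runtime type checks; this sketch only illustrates the behavior the bot flagged.)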
🧰 Tools
🪛 Ruff (0.15.2)

[warning] 407-407: Avoid specifying long messages outside the exception class

(TRY003)


ericspod (Member) commented:

We shouldn't check the type of accumulation_steps at all, since the type annotation is present; only the correct value range should be enforced.

@coderabbitai coderabbitai bot replied:

@ericspod, understood — I'll withdraw that suggestion. The type annotation is the right place to express the type contract, and the runtime guard should only enforce the value range (accumulation_steps < 1).



acc_steps, lr = 4, 0.1
batches = [{CommonKeys.IMAGE: torch.randn(1, 4), CommonKeys.LABEL: torch.randn(1, 1)} for _ in range(acc_steps)]

ref_model, test_model, ref_opt, test_opt, init_weight = _make_model_pair(lr)
@coderabbitai coderabbitai bot commented:

⚠️ Potential issue | 🟡 Minor

Remove unused init_weight plumbing from model-pair helper.

init_weight is returned/unpacked but never used by tests.

Proposed fix
-        ref_model, test_model, ref_opt, test_opt, init_weight = _make_model_pair(lr)
+        ref_model, test_model, ref_opt, test_opt = _make_model_pair(lr)
...
-        ref_model, test_model, ref_opt, test_opt, init_weight = _make_model_pair(lr)
+        ref_model, test_model, ref_opt, test_opt = _make_model_pair(lr)
...
-        ref_model, test_model, ref_opt, test_opt, init_weight = _make_model_pair(lr)
+        ref_model, test_model, ref_opt, test_opt = _make_model_pair(lr)

-def _make_model_pair(lr):
+def _make_model_pair(lr):
     """Create a reference and test model pair with identical initial weights."""
     ref_model = nn.Linear(4, 1, bias=False)
     init_weight = ref_model.weight.data.clone()
@@
-    return ref_model, test_model, ref_opt, test_opt, init_weight
+    return ref_model, test_model, ref_opt, test_opt

Also applies to: 271-271, 303-303, 328-339

🧰 Tools
🪛 Ruff (0.15.2)

[warning] 240-240: Unpacked variable init_weight is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


aymuos15 added 2 commits March 3, 2026 15:00
Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (1)
monai/engines/utils.py (1)

406-407: ⚠️ Potential issue | 🟡 Minor

Reject bool for accumulation_steps.

True currently passes because bool is an int subclass, so invalid config can silently map to 1.

Proposed fix
-        if not isinstance(accumulation_steps, int) or accumulation_steps < 1:
+        if isinstance(accumulation_steps, bool) or not isinstance(accumulation_steps, int) or accumulation_steps < 1:
             raise ValueError(f"`accumulation_steps` must be a positive integer, got {accumulation_steps!r}.")


📥 Commits

Reviewing files that changed from the base of the PR and between a3eca14 and 53c5dc5.

📒 Files selected for processing (2)
  • monai/engines/utils.py
  • tests/engines/test_gradient_accumulation.py

Comment on lines +28 to +29
INVALID_ACCUMULATION_STEPS = [(0,), (-1,), (2.5,), ("2",)]

@coderabbitai coderabbitai bot commented:

⚠️ Potential issue | 🟡 Minor

Add explicit bool invalid-input coverage.

This suite misses True, which is the key edge case for the bool-as-int validation bug.

Proposed fix
-INVALID_ACCUMULATION_STEPS = [(0,), (-1,), (2.5,), ("2",)]
+INVALID_ACCUMULATION_STEPS = [(0,), (-1,), (2.5,), ("2",), (True,), (False,)]

As per coding guidelines, "Ensure new or modified definitions will be covered by existing or new unit tests."

Also applies to: 58-63



ericspod commented Mar 7, 2026

Hi @aymuos15, thanks for looking into this, which had fallen by the wayside a bit. I've looked over your solution and it's similar to what was proposed before. I feel it's not quite the right way of going about it, since it relies on the internal members of many objects and therefore on a lot of assumptions about their structure. The way I would go about this feature is to modify SupervisedTrainer directly with a new constructor argument and object member accumulation_steps to serve the same purpose as it does here. You can then modify the _iteration method to selectively do the optimisation step and scale the loss value. This would be much cleaner and rely on fewer assumptions overall, and your tests would be much simpler (though unfortunately you'd have to redo them to use SupervisedTrainer). Does that sound like a way forward with this feature?
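A rough, torch-free sketch of the design described in this comment, with accumulation handled by plain conditionals inside the iteration (TinyTrainer and every name in it are illustrative, not the actual SupervisedTrainer code; a float stands in for the accumulated gradient buffers):

```python
# Illustrative-only miniature of a trainer with a native
# `accumulation_steps` constructor argument; not real MONAI code.
class TinyTrainer:
    def __init__(self, accumulation_steps: int = 1):
        if accumulation_steps < 1:
            raise ValueError("`accumulation_steps` must be a positive integer")
        self.accumulation_steps = accumulation_steps
        self.pending = 0.0    # stands in for accumulated .grad buffers
        self.steps_taken = 0  # stands in for optimizer.step() calls

    def _iteration(self, iteration: int, loss: float, epoch_length: int) -> dict:
        # scale the loss used for backward; the reported loss stays unscaled
        self.pending += loss / self.accumulation_steps
        at_window = iteration % self.accumulation_steps == 0
        at_epoch_end = iteration % epoch_length == 0  # boundary flush
        if at_window or at_epoch_end:
            self.steps_taken += 1  # optimizer.step()
            self.pending = 0.0     # optimizer.zero_grad()
        return {"loss": loss}      # unscaled output for handlers/loggers

trainer = TinyTrainer(accumulation_steps=4)
outs = [trainer._iteration(i, 1.0, epoch_length=10) for i in range(1, 11)]
assert trainer.steps_taken == 3  # steps at iterations 4 and 8, flush at 10
assert outs[0]["loss"] == 1.0    # output keeps the unscaled loss
```

The reworked change in the later commits follows this shape: a validated constructor argument plus simple conditionals in _iteration(), with no patching of optimizer, loss, or scaler internals.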

aymuos15 added 2 commits March 9, 2026 09:25
Replace external GradientAccumulation callable class with a native
`accumulation_steps` constructor parameter on SupervisedTrainer, per
reviewer feedback. This eliminates monkey-patching of optimizer/loss/scaler
internals and instead uses simple conditionals in `_iteration()`.

Based on feedback from @ericspod

Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>

aymuos15 commented Mar 9, 2026

@ericspod Apologies, I think I misunderstood the previous recommendations over there. I have now reverted to what you said. Thank you very much. Please let me know if this is aligned now.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (1)
tests/engines/test_gradient_accumulation.py (1)

25-25: ⚠️ Potential issue | 🟡 Minor

Add (True,) to invalid input coverage.

isinstance(True, int) is True in Python. If the implementation doesn't guard against bools, True would pass validation but is semantically invalid.

-INVALID_ACCUMULATION_STEPS = [(0,), (-1,), (2.5,), ("2",)]
+INVALID_ACCUMULATION_STEPS = [(0,), (-1,), (2.5,), ("2",), (True,)]
🧹 Nitpick comments (1)
tests/engines/test_gradient_accumulation.py (1)

94-95: Consider strict=True for zip calls.

Catches length mismatches between model parameters. Applies to lines 94, 127, 161, 196.

-        for p_test, p_ref in zip(test_model.parameters(), ref_model.parameters()):
+        for p_test, p_ref in zip(test_model.parameters(), ref_model.parameters(), strict=True):
Additional inline comment (monai/engines/trainer.py, lines 171-172): the accumulation_steps validation accepts bools because bool is a subclass of int; explicitly reject booleans (e.g. require type(accumulation_steps) is int, or add not isinstance(accumulation_steps, bool)) while still enforcing accumulation_steps >= 1, and adjust the ValueError path accordingly.


📥 Commits

Reviewing files that changed from the base of the PR and between 53c5dc5 and 61d2a9f.

📒 Files selected for processing (2)
  • monai/engines/trainer.py
  • tests/engines/test_gradient_accumulation.py

Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>


Development

Successfully merging this pull request may close these issues.

Add gradient accumulation functionality to SupervisedTrainer
